Add ExportFriendlyMultiheadAttention for dynamic shape torch.export#2

Closed
rbavery wants to merge 3 commits into main from ryan/export-friendly-mha
Conversation


@rbavery rbavery commented Jan 22, 2026

Summary

This PR adds ExportFriendlyMultiheadAttention, a custom MultiheadAttention implementation that enables torch.export with dynamic shapes (e.g., variable image H/W).

Problem

When exporting SAM3 with dynamic image dimensions using torch.export, the export fails with:

Could not guard on data-dependent expression Eq(u0*u1, 5184)
Caused by: F.multi_head_attention_forward (nn/functional.py:6475)

This happens because nn.MultiheadAttention calls F.multi_head_attention_forward, which has internal guards on sequence length for:

  • Shape validation of attention masks
  • Fast-path vs slow-path selection
  • Self-attention vs cross-attention detection

When the sequence length is symbolic (e.g., H*W where H/W are dynamic), these guards cannot be statically evaluated, causing export to fail.

Note: The commonly suggested workaround sdpa_kernel([SDPBackend.MATH]) does NOT work for this case, because the guard failure happens before scaled_dot_product_attention is ever called; it occurs in F.multi_head_attention_forward's shape-validation code.

Solution

ExportFriendlyMultiheadAttention bypasses F.multi_head_attention_forward entirely by:

  1. Manually projecting Q, K, V using the same combined in_proj_weight
  2. Calling F.scaled_dot_product_attention directly
  3. Avoiding all shape validation guards
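The three steps above can be sketched as follows. This is a simplified, hypothetical illustration (batch-first self-attention, no masks, class and method names chosen for this sketch); the actual ExportFriendlyMultiheadAttention in the PR also mirrors nn.MultiheadAttention's parameter layout and handles masks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExportFriendlyMHASketch(nn.Module):
    """Minimal sketch: manual Q/K/V projection + direct SDPA call."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Same combined in-projection layout as nn.MultiheadAttention.
        self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = nn.Parameter(torch.zeros(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        nn.init.xavier_uniform_(self.in_proj_weight)

    def forward(self, query, key, value):
        B, L, E = query.shape
        # Step 1: project Q, K, V with slices of the combined weight.
        w_q, w_k, w_v = self.in_proj_weight.chunk(3)
        b_q, b_k, b_v = self.in_proj_bias.chunk(3)
        q, k, v = (F.linear(query, w_q, b_q),
                   F.linear(key, w_k, b_k),
                   F.linear(value, w_v, b_v))

        # (B, L, E) -> (B, num_heads, L, head_dim)
        def split(x):
            return x.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 2: call SDPA directly -- no sequence-length guards run here.
        out = F.scaled_dot_product_attention(split(q), split(k), split(v))

        # Step 3: merge heads and apply the output projection.
        out = out.transpose(1, 2).reshape(B, L, E)
        return self.out_proj(out)
```

Because the forward pass never enters F.multi_head_attention_forward, none of its fast-path or mask-shape checks can introduce guards on a symbolic sequence length.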

Also includes:

  • from_nn_mha() classmethod to create from existing nn.MultiheadAttention with weight copying
  • replace_mha_with_export_friendly() utility function to recursively replace all MHA modules in a model
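The replacement utility can be sketched like this (function name and factory argument are hypothetical for illustration; the PR's replace_mha_with_export_friendly presumably uses from_nn_mha internally to copy weights):

```python
import torch.nn as nn

def replace_mha_sketch(module: nn.Module, factory) -> int:
    """Recursively swap every nn.MultiheadAttention child via `factory`.

    Returns the number of modules replaced.
    """
    replaced = 0
    for name, child in module.named_children():
        if isinstance(child, nn.MultiheadAttention):
            # e.g. factory = ExportFriendlyMultiheadAttention.from_nn_mha
            setattr(module, name, factory(child))
            replaced += 1
        else:
            replaced += replace_mha_sketch(child, factory)
    return replaced
```

Recursing only into non-MHA children avoids descending into the replaced module, and using setattr on the parent keeps the replacement visible to state_dict and export.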

Test Results

Replaced 61 MultiheadAttention modules

✓ eager_mode: PASS (model works after replacement)
✓ pt2_export: PASS (with strict=False, dynamic H/W works!)
  Graph has 6660 nodes

Usage

from sam3.model.model_misc import replace_mha_with_export_friendly

# Load model
model = build_sam3_image_model(...)

# Replace all MHA modules before export
num_replaced = replace_mha_with_export_friendly(model, verbose=True)
print(f"Replaced {num_replaced} modules")

# Now export with dynamic shapes works
from torch.export import export
from torch.export.dynamic_shapes import Dim

height_mult = Dim("height_mult", min=1, max=6)
width_mult = Dim("width_mult", min=1, max=6)

exported = export(
    model,
    args=(images, ...),
    dynamic_shapes={"images": {2: 336 * height_mult, 3: 336 * width_mult}},
    strict=False,
)

Related Issues

…tensor

Create the box scale tensor directly on the target device instead of using
pin_memory().to(device, non_blocking=True). This enables:

- CPU-only inference (pin_memory requires CUDA)
- Apple MPS inference (pin_memory not supported)
- PT2 export without runtime patching

The scale tensor is always exactly 4 floats (16-32 bytes). For such a small
tensor, the pin_memory overhead likely exceeds any async transfer benefit.
Creating the tensor directly on device avoids the CPU→GPU transfer entirely.
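The change described above amounts to something like the following (function name is hypothetical; only the construction pattern matters):

```python
import torch

def box_scale(W: int, H: int, boxes_xyxy: torch.Tensor) -> torch.Tensor:
    # Build the 4-element [W, H, W, H] scale tensor directly on the boxes'
    # device and dtype, instead of pin_memory() + .to(..., non_blocking=True).
    return torch.tensor(
        [W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device
    ).view(1, 1, 4)
```

Since device placement now follows the input tensor, the same code path works on CPU, CUDA, and MPS without any pinned-memory requirement.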

This adds a custom MultiheadAttention implementation that bypasses
F.multi_head_attention_forward to enable torch.export with dynamic
shapes (e.g., variable image H/W).

The problem: nn.MultiheadAttention uses F.multi_head_attention_forward
which has internal guards on sequence length (e.g., Eq(seq_len, 5184))
that fail during torch.export because the sequence length is symbolic.

The solution: ExportFriendlyMultiheadAttention:
- Manually projects Q, K, V using the same combined in_proj_weight
- Calls F.scaled_dot_product_attention directly
- Avoids all shape validation guards in F.multi_head_attention_forward

Also adds replace_mha_with_export_friendly() utility function to
recursively replace all nn.MultiheadAttention modules in a model.

Related PyTorch issues:
- pytorch/pytorch#170127
- pytorch/pytorch#124502
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
scale = scale.view(1, 1, 4)
scale = torch.tensor(

Removes memory pinning for CPU export. See #1.

return super().forward(*args, **kwargs)


class ExportFriendlyMultiheadAttention(nn.Module):

While export now works, we still need to validate that the exported model produces the same outputs as the original eager-mode model.

During torch.export with dynamic H/W dimensions, SymInt values cannot
be used as dict keys. These caches prevented dynamic shape export.

Changes:
- position_encoding.py: Remove (H, W) keyed cache in forward()
- decoder.py: Remove coord_cache dict lookup in _get_rpb_matrix()

The computation is cheap (just torch.arange) so always computing is
acceptable for export use cases.
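The cache-removal pattern can be illustrated with a small sketch (function name hypothetical): rather than memoizing by a (H, W) dict key, which fails when H and W are SymInts, the coordinate grid is simply recomputed on every call.

```python
import torch

def coord_grid(H: int, W: int, device=None):
    # Cheap to recompute each call: two torch.arange calls plus a meshgrid.
    # A dict cache keyed by (H, W) would break export when H/W are symbolic.
    ys = torch.arange(H, device=device)
    xs = torch.arange(W, device=device)
    return torch.meshgrid(ys, xs, indexing="ij")
```

Both arange and meshgrid trace cleanly with symbolic sizes, so the export graph carries the computation instead of a shape-keyed lookup.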

rbavery commented Feb 7, 2026

Duplicate of #3.

@rbavery rbavery closed this Feb 7, 2026